import os
from datetime import datetime

import torch
from torch.utils.data import TensorDataset, DataLoader

from transformers import BertTokenizer, BertForSequenceClassification

# Load the pretrained tokenizer and classification head.
# NOTE(review): num_labels=10 although only 3 distinct labels appear in the
# toy data below — presumably the real dataset has 10 classes; confirm.
tokenizer = BertTokenizer.from_pretrained('model/chinese-macbert-base')
model = BertForSequenceClassification.from_pretrained('model/chinese-macbert-base', num_labels=10)

# Toy training data: one text per example, integer class label.
texts = ['这是一个正样本', '这是一个负样本', '这是一个错误样本']
labels = [0, 1, 2]

# Tokenize the whole batch at once: pad to the longest sequence in the
# batch and truncate to the model's maximum input length.
encoded_inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
input_ids = encoded_inputs['input_ids']
attention_mask = encoded_inputs['attention_mask']
labels = torch.tensor(labels)

# Wrap the tensors in a dataset/loader for shuffled mini-batch training.
dataset = TensorDataset(input_ids, attention_mask, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# AdamW with the standard BERT fine-tuning learning rate.
# No separate loss function is needed: BertForSequenceClassification
# computes the cross-entropy loss internally when `labels` is passed
# (the training loop reads `outputs.loss`), so the previous unused
# CrossEntropyLoss instance has been removed.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Fine-tune for a fixed number of epochs, reporting the mean batch loss.
model.train()
num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    # TensorDataset batches unpack directly into (ids, mask, labels).
    for batch_input_ids, batch_attention_mask, batch_labels in dataloader:
        optimizer.zero_grad()
        # Passing `labels` makes the model return the cross-entropy loss.
        outputs = model(batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels)
        loss = outputs.loss
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_loss = running_loss / len(dataloader)
    print(f'Epoch {epoch + 1}, average loss: {avg_loss:.4f}')

# Save the fine-tuned weights under a timestamped name (second precision)
# so successive runs never overwrite each other.
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
model_path = "train/trained_model" + timestamp + ".pt"
# torch.save does NOT create parent directories; make sure "train/" exists
# before writing, otherwise the save raises FileNotFoundError.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
print(model_path)
torch.save(model.state_dict(), model_path)

# Rebuild the architecture from the pretrained checkpoint, then overwrite
# its parameters with the fine-tuned state dict saved above.
model = BertForSequenceClassification.from_pretrained('model/chinese-macbert-base', num_labels=10)
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)

# Single-example prediction with the reloaded model.
text = '这是一个错误样本'
encoded = tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    max_length=128,
    padding=True,
    truncation=True,
    return_tensors='pt'
)

input_ids = encoded['input_ids']
attention_mask = encoded['attention_mask']

# Switch to eval mode (disables dropout — a freshly constructed model is in
# train mode by default) and disable gradient tracking; both are required
# for correct, memory-efficient inference.
model.eval()
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits

# The predicted class is the argmax over the label logits.
predicted_label = torch.argmax(logits, dim=1).item()

print(f"预测标签: {predicted_label}")